test_that("simnple tracing works", {
  fn <- function(x) {
    torch_relu(x)
  }

  input <- torch_tensor(c(-1, 0, 1))
  tr_fn <- jit_trace(fn, input)

  expect_equal_to_tensor(tr_fn(input), fn(input))
})

test_that("print the graph works", {
  fn <- function(x) {
    torch_relu(x)
  }

  input <- torch_tensor(c(-1, 0, 1))
  tr_fn <- jit_trace(fn, input)

  expect_output(print(tr_fn$graph), regexp = "graph")
})

test_that("modules are equivalent", {
  Net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      self$dropout1 <- nn_dropout2d(0.25)
      self$dropout2 <- nn_dropout2d(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x <- self$conv1(x)
      x <- nnf_relu(x)
      x <- self$conv2(x)
      x <- nnf_relu(x)
      x <- nnf_max_pool2d(x, 2)
      x <- self$dropout1(x)
      x <- torch_flatten(x, start_dim = 2)
      x <- self$fc1(x)
      x <- nnf_relu(x)
      x <- self$dropout2(x)
      x <- self$fc2(x)
      output <- nnf_log_softmax(x, dim = 1)
      output
    }
  )

  net <- Net()
  net$eval()

  # currently we need to detach all parameters in order to
  # JIT compile. We need to support modules to avoid that.
  for (p in net$parameters) {
    p$detach_()
  }

  fn <- function(x) {
    net(x)
  }

  input <- torch_randn(100, 1, 28, 28)
  out <- fn(input)

  tr_fn <- jit_trace(fn, input)
  expect_true(torch_allclose(fn(input), tr_fn(input)))
})

test_that("can save and reload", {
  fn <- function(x) {
    torch_relu(x)
  }

  input <- torch_tensor(c(-1, 0, 1))
  tr_fn <- jit_trace(fn, input)

  tmp <- tempfile("tst", fileext = "pt")
  jit_save(tr_fn, tmp)

  f <- jit_load(tmp)
  expect_equal_to_tensor(f(input), fn(input))
})

test_that("errors gracefully when passing unsupported inputs", {
  fn <- function(x) {
    torch_relu(x)
  }

  expect_error(
    jit_trace(fn, "a")
  )
})

test_that("can take lists of tensors as input", {
  fn <- function(x) {
    torch_stack(x)
  }
  x <- list(torch_tensor(1), torch_tensor(2))

  tr_fn <- jit_trace(fn, x)
  expect_equal_to_tensor(fn(x), tr_fn(x))
})

test_that("can output a list of tensors", {
  fn <- function(x) {
    list(x, x + 1)
  }
  x <- torch_tensor(1)
  tr_fn <- jit_trace(fn, x)
  expect_equal_to_tensor(fn(x)[[1]], tr_fn(x)[[1]])
  expect_equal_to_tensor(fn(x)[[2]], tr_fn(x)[[2]])
})

test_that("fn can take more than 1 argument", {
  fn <- function(x, y) {
    list(x, x + y)
  }

  x <- torch_tensor(1)
  y <- torch_tensor(2)

  tr_fn <- jit_trace(fn, x, y)
  expect_equal_to_tensor(fn(x, y)[[1]], tr_fn(x, y)[[1]])
  expect_equal_to_tensor(fn(x, y)[[2]], tr_fn(x, y)[[2]])

  expect_error(
    tr_fn <- jit_trace(fn, x = x, y = y)
  )
})

test_that("can have named inputs and outputs", {
  fn <- function(x) {
    list(x = x$t1, y = x$t2)
  }

  x <- list(
    t1 = torch_tensor(1),
    t2 = torch_tensor(2)
  )

  tr_fn <- jit_trace(fn, x, strict = FALSE)

  expect_equal(
    tr_fn(x),
    fn(x)
  )
})

test_that("tuple inputs are correctly handled", {
  fn <- function(x) {
    jit_tuple(list(x = x$t1, y = x$t2))
  }

  x <- jit_tuple(list(
    t1 = torch_tensor(1),
    t2 = torch_tensor(2)
  ))

  tr_fn <- jit_trace(fn, x, strict = FALSE)

  # returned named tuples will loose their names
  expect_equal(tr_fn(x), list(torch_tensor(1), torch_tensor(2)))

  # if the model has been traced with a named tuple we
  # will expect a tuple back too.
  expect_error(
    tr_fn(list(t1 = torch_tensor(1), t2 = torch_tensor(2)))
  )
})

test_that("tuple casting", {
  fn <- function(y) {
    jit_tuple(list(y[[1]], y[[2]]))
  }

  x <- jit_tuple(list(torch_tensor(1), torch_tensor(2)))
  tr_fn <- jit_trace(fn, x)

  expect_error(
    tr_fn(list(torch_tensor(1), torch_tensor(2)))
  )

  expect_equal(
    tr_fn(jit_tuple(list(torch_tensor(1), torch_tensor(2)))),
    list(torch_tensor(1), torch_tensor(2))
  )
})

test_that("trace a nn module", {
  test_module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
      self$norm <- nn_batch_norm1d(10)
      self$par <- nn_parameter(torch_tensor(2))
      self$buff <- nn_buffer(torch_randn(10, 5))
      self$constant <- 1
      self$hello <- list(torch_tensor(1), torch_tensor(2), "hello")
    },
    forward = function(x) {
      self$par * x
    },
    testing = function(x) {
      x %>%
        self$linear()
    },
    test_constant = function(x) {
      x + self$constant + self$hello[[2]]
    }
  )

  mod <- test_module()

  expect_error(
    m <- jit_trace_module(
      mod,
      forward = torch_randn(1),
      testing = list(torch_randn(10, 10)),
      test_constant = list(torch_tensor(1))
    ),
    regexp = NA
  )

  expect_length(m$parameters, 5)
  expect_length(m$buffers, 4)
  expect_length(m$modules, 3)

  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(4))
  with_no_grad(m$par$zero_())
  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(0))

  x <- torch_randn(10, 10)
  expect_true(torch_allclose(m$testing(x), mod$testing(x)))
  with_no_grad({
    m$linear$weight$zero_()$add_(1)
    mod$linear$weight$zero_()$add_(1)
  })
  expect_true(torch_allclose(m$testing(x), mod$testing(x)))

  expect_true(torch_allclose(m$test_constant(torch_tensor(2)), torch_tensor(5)))
})

test_that("dont crash when gcing a method", {
  mod <- jit_trace(nn_linear(10, 10), torch_randn(10, 10))
  gc()
  forward <- mod$forward
  rm(forward)
  gc()
  gc()
  expect_error(regexp = NA, mod$forward)
})

test_that("we can save traced modules", {
  test_module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
      self$norm <- nn_batch_norm1d(10)
      self$par <- nn_parameter(torch_tensor(2))
      self$buff <- nn_buffer(torch_randn(10, 5))
      self$constant <- 1
      self$hello <- list(torch_tensor(1), torch_tensor(2), "hello")
    },
    forward = function(x) {
      self$par * x
    },
    testing = function(x) {
      x %>%
        self$linear()
    },
    test_constant = function(x) {
      x + self$constant + self$hello[[2]]
    }
  )

  mod <- test_module()
  m <- jit_trace_module(
    mod,
    forward = torch_randn(1),
    testing = list(torch_randn(10, 10)),
    test_constant = list(torch_tensor(1))
  )

  jit_save(m, "tracedmodule.pt")
  rm(m)
  gc()
  gc()

  m <- jit_load("tracedmodule.pt")

  expect_length(m$parameters, 5)
  expect_length(m$buffers, 4)
  expect_length(m$modules, 3)

  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(4))
  with_no_grad(m$par$zero_())
  expect_equal_to_tensor(m(torch_tensor(2)), torch_tensor(0))

  x <- torch_randn(10, 10)
  expect_true(torch_allclose(m$testing(x), mod$testing(x)))
  with_no_grad({
    m$linear$weight$zero_()$add_(1)
    mod$linear$weight$zero_()$add_(1)
  })
  expect_true(torch_allclose(m$testing(x), mod$testing(x), atol=1e-5))

  expect_true(torch_allclose(m$test_constant(torch_tensor(2)), torch_tensor(5)))
})

test_that("trace a module", {
  Net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      self$dropout1 <- nn_dropout2d(0.25)
      self$dropout2 <- nn_dropout2d(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x <- self$conv1(x)
      x <- nnf_relu(x)
      x <- self$conv2(x)
      x <- nnf_relu(x)
      x <- nnf_max_pool2d(x, 2)
      x <- self$dropout1(x)
      x <- torch_flatten(x, start_dim = 2)
      x <- self$fc1(x)
      x <- nnf_relu(x)
      x <- self$dropout2(x)
      x <- self$fc2(x)
      output <- nnf_log_softmax(x, dim = 1)
      output
    }
  )

  net <- Net()
  net$eval()

  input <- torch_randn(100, 1, 28, 28)
  out <- net(input)

  tr_fn <- jit_trace(net, input)
  expect_equal_to_tensor(net(input), tr_fn(input), tolerance = 1e-6)
})

test_that("Can recover from errors in the traced method", {
  module <- nn_module(
    initialize = function() {},
    forward = function(x) {
      stop("The error abcde")
    }
  )

  expect_error(
    jit_trace(module(), torch_tensor(1)),
    regexp = ".*abcde"
  )

  expect_error(
    regexp = "You must initialize the nn_module before tracing",
    jit_trace(module, torch_tensor(1))
  )

  expect_error(
    regexp = "jit_trace needs a function or nn_module",
    jit_trace(1, torch_tensor(1))
  )
})

test_that("we get a good error message when trying to call a method from a submodule", {
  module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
    },
    forward = function(x) {
      self$linear(x)
    }
  )

  m <- jit_trace(module(), torch_randn(100, 10))

  expect_error(
    m$linear(torch_randn(10, 10)),
    regexp = "Methods from submodules of traced modules are not traced"
  )
})

test_that("errors in the tracer are correctly captured", {
  module <- nn_module(
    initialize = function() {
      self$linear <- nn_linear(10, 10)
    },
    forward = function(x) {
      self$linear(x)
      1
    }
  )

  expect_error(
    jit_trace(module(), torch_randn(10, 10)),
    regexp = ".*Only tensors, lists, tuples of tensors"
  )
})

test_that("we can include traced module as a submodule and trace", {
  module <- nn_module(
    initialize = function() {
      self$linear <- jit_trace(nn_linear(10, 10), torch_randn(10, 10))
    },
    forward = function(x) {
      self$linear(x)
    }
  )
  mod <- module()

  m <- jit_trace(mod, torch_randn(10, 10))

  x <- torch_randn(10, 10)
  expect_equal_to_tensor(m(x), mod(x))
  expect_equal_to_tensor(m$linear(x), mod$linear(x))
})

test_that("can save module for mobile", {
  Net <- nn_module(
    "Net",
    initialize = function() {
      self$conv1 <- nn_conv2d(1, 32, 3, 1)
      self$conv2 <- nn_conv2d(32, 64, 3, 1)
      self$dropout1 <- nn_dropout2d(0.25)
      self$dropout2 <- nn_dropout2d(0.5)
      self$fc1 <- nn_linear(9216, 128)
      self$fc2 <- nn_linear(128, 10)
    },
    forward = function(x) {
      x <- self$conv1(x)
      x <- nnf_relu(x)
      x <- self$conv2(x)
      x <- nnf_relu(x)
      x <- nnf_max_pool2d(x, 2)
      x <- self$dropout1(x)
      x <- torch_flatten(x, start_dim = 2)
      x <- self$fc1(x)
      x <- nnf_relu(x)
      x <- self$dropout2(x)
      x <- self$fc2(x)
      output <- nnf_log_softmax(x, dim = 1)
      output
    }
  )

  net <- Net()
  net$eval()

  input <- torch_randn(100, 1, 28, 28)
  out <- net(input)

  tr_fn <- jit_trace(net, input)

  tmp <- tempfile("tst", fileext = ".pt")
  jit_save_for_mobile(tr_fn, tmp)

  f <- jit_load(tmp)
  expect_equal_to_tensor(net(input), f(input), tol = 1e-6)
})

test_that("can save function for mobile", {
  fn <- function(x) {
    torch_relu(x)
  }

  input <- torch_tensor(c(-1, 0, 1))
  tr_fn <- jit_trace(fn, input)

  tmp <- tempfile("tst", fileext = ".pt")
  jit_save_for_mobile(tr_fn, tmp)

  f <- jit_load(tmp)
  expect_equal_to_tensor(torch_relu(input), f(input))
})
